/*
 * Copyright 2012 Red Hat, Inc. and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.drools.scorecards.pmml;

import org.dmg.pmml.pmml_4_2.descr.*;
import org.drools.core.util.StringUtils;
import org.kie.pmml.pmml_4_2.extensions.PMMLExtensionNames;
import org.kie.pmml.pmml_4_2.extensions.PMMLIOAdapterMode;
import org.drools.scorecards.StringUtil;
import org.drools.scorecards.parser.xls.XLSKeywords;

import java.math.BigInteger;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Locale;

public class ScorecardPMMLGenerator {

    private static final String PMML_VERSION = "4.2.1";

    public PMML generateDocument(Scorecard pmmlScorecard) {
        //first clean up the scorecard
        removeEmptyExtensions(pmmlScorecard);
        createAndSetPredicates(pmmlScorecard);

        //second add additional elements to scorecard
        createAndSetOutput(pmmlScorecard);
        repositionExternalClassExtensions(pmmlScorecard);

        Extension scorecardPackage = ScorecardPMMLUtils.getExtension(pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), PMMLExtensionNames.MODEL_PACKAGE );
        if ( scorecardPackage != null) {
            pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().remove(scorecardPackage);
        }
        Extension importsExt = ScorecardPMMLUtils.getExtension(pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), PMMLExtensionNames.MODEL_IMPORTS );
        if ( importsExt != null) {
            pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().remove(importsExt);
        }
        Extension agendaGroupExt = ScorecardPMMLUtils.getExtension(pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), PMMLExtensionNames.AGENDA_GROUP );
        if ( agendaGroupExt != null) {
            pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().remove(agendaGroupExt);
        }
        Extension ruleFlowGroupExt = ScorecardPMMLUtils.getExtension(pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), PMMLExtensionNames.RULEFLOW_GROUP);
        if ( ruleFlowGroupExt != null) {
            pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().remove(ruleFlowGroupExt);
        }

        //now create the PMML document
        PMML pmml = new PMML();
        pmml.setVersion(PMML_VERSION);
        Header header = new Header();
        Timestamp timestamp = new Timestamp();
        timestamp.getContent().add(new SimpleDateFormat("yyyy.MM.dd 'at' HH:mm:ss z", Locale.ENGLISH).format(new Date()));
        header.setTimestamp(timestamp);
        header.setDescription("generated by the drools-scorecards module");
        header.getExtensions().add(scorecardPackage);
        header.getExtensions().add(importsExt);

        if (ruleFlowGroupExt != null){
            header.getExtensions().add(ruleFlowGroupExt);
        }
        if (agendaGroupExt != null){
            header.getExtensions().add(agendaGroupExt);
        }
        pmml.setHeader(header);

        createAndSetDataDictionary(pmml, pmmlScorecard);
        pmml.getAssociationModelsAndBaselineModelsAndClusteringModels().add(pmmlScorecard);
        removeAttributeFieldExtension(pmmlScorecard);
        return pmml;
    }

    private void repositionExternalClassExtensions(Scorecard pmmlScorecard) {
        Characteristics characteristics = null;
        for (Object obj : pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
            if ( obj instanceof  Characteristics ) {
                characteristics = (Characteristics) obj;
                break;
            }
        }
        for (Object obj : pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
            if ( obj instanceof MiningSchema ) {
                MiningSchema schema = (MiningSchema)obj;
                    Extension adapter = new Extension();
                    adapter.setName( PMMLExtensionNames.IO_ADAPTER );
                    adapter.setValue( PMMLIOAdapterMode.BEAN.name() );
                    schema.getExtensions().add( adapter );
                for (MiningField miningField : schema.getMiningFields()) {
                    String fieldName = miningField.getName();
                    for (Characteristic characteristic : characteristics.getCharacteristics()){
                        String characteristicName = ScorecardPMMLUtils.extractFieldNameFromCharacteristic(characteristic);
                        if (fieldName.equalsIgnoreCase(characteristicName)){
                            Extension extension = ScorecardPMMLUtils.getExtension(characteristic.getExtensions(), PMMLExtensionNames.EXTERNAL_CLASS );
                            if ( extension != null ) {
                                characteristic.getExtensions().remove(extension);
                                if ( ScorecardPMMLUtils.getExtension(miningField.getExtensions(), PMMLExtensionNames.EXTERNAL_CLASS ) == null ) {
                                    miningField.getExtensions().add(extension);
                                }
                            }
                        }
                    }
                }
                MiningField targetField = new MiningField();
                targetField.setName( ScorecardPMMLExtensionNames.DEFAULT_PREDICTED_FIELD );
                targetField.setUsageType( FIELDUSAGETYPE.PREDICTED );
                schema.getMiningFields().add( targetField );
            } else if ( obj instanceof Output ) {
                Extension adapter = new Extension();
                adapter.setName( PMMLExtensionNames.IO_ADAPTER );
                adapter.setValue( PMMLIOAdapterMode.BEAN.name() );
                ( (Output) obj ).getExtensions().add( adapter );

            }
        }
    }

    private void removeAttributeFieldExtension(Scorecard pmmlScorecard) {
        for (Object obj : pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
            if (obj instanceof Characteristics) {
                Characteristics characteristics = (Characteristics) obj;
                for (org.dmg.pmml.pmml_4_2.descr.Characteristic characteristic : characteristics.getCharacteristics()) {
                    for (Attribute attribute : characteristic.getAttributes()) {
                        Extension fieldExtension = ScorecardPMMLUtils.getExtension(attribute.getExtensions(), ScorecardPMMLExtensionNames.CHARACTERTISTIC_FIELD);
                        if ( fieldExtension != null ) {
                            attribute.getExtensions().remove(fieldExtension);
                            //break;
                        }
                    }
                }
            }
        }
    }

    private void createAndSetDataDictionary(PMML pmml, Scorecard pmmlScorecard) {

        DataDictionary dataDictionary = new DataDictionary();
        pmml.setDataDictionary(dataDictionary);
        int ctr = 0;
        for (Object obj : pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
            if (obj instanceof Characteristics) {
                Characteristics characteristics = (Characteristics) obj;
                for (org.dmg.pmml.pmml_4_2.descr.Characteristic characteristic : characteristics.getCharacteristics()) {

                    DataField dataField = new DataField();
                    Extension dataTypeExtension = ScorecardPMMLUtils.getExtension(characteristic.getExtensions(), ScorecardPMMLExtensionNames.CHARACTERTISTIC_DATATYPE);
                    String dataType = dataTypeExtension.getValue();
                    String factType = ScorecardPMMLUtils.getExtensionValue(characteristic.getExtensions(), ScorecardPMMLExtensionNames.CHARACTERTISTIC_FACTTYPE);

                    if ( factType != null ){
                        Extension extension = new Extension();
                        extension.setName("FactType");
                        extension.setValue(factType);
                        dataField.getExtensions().add(extension);
                    }


                    if (XLSKeywords.DATATYPE_DOUBLE.equalsIgnoreCase(dataType)) {
                        dataField.setDataType(DATATYPE.DOUBLE);
                        dataField.setOptype(OPTYPE.CONTINUOUS);
                    } else if (XLSKeywords.DATATYPE_INTEGER.equalsIgnoreCase(dataType)) {
                        dataField.setDataType(DATATYPE.INTEGER);
                        dataField.setOptype(OPTYPE.CONTINUOUS);
                    } else if (XLSKeywords.DATATYPE_NUMBER.equalsIgnoreCase(dataType)) {
                        dataField.setDataType(DATATYPE.DOUBLE);
                        dataField.setOptype(OPTYPE.CONTINUOUS);
                    } else if (XLSKeywords.DATATYPE_TEXT.equalsIgnoreCase(dataType)) {
                        dataField.setDataType(DATATYPE.STRING);
                        dataField.setOptype(OPTYPE.CATEGORICAL);
                    } else if (XLSKeywords.DATATYPE_BOOLEAN.equalsIgnoreCase(dataType)) {
                        dataField.setDataType(DATATYPE.BOOLEAN);
                        dataField.setOptype(OPTYPE.CATEGORICAL);
                    }
                    String field = "";
                    for (Attribute attribute : characteristic.getAttributes()) {
                        for (Extension extension : attribute.getExtensions()) {
                            if ( ScorecardPMMLExtensionNames.CHARACTERTISTIC_FIELD.equalsIgnoreCase(extension.getName())) {
                                field = extension.getValue();
                                break;
                            }//
                        }
                    }
                    dataField.setName(field);
                    dataDictionary.getDataFields().add(dataField);
                    characteristic.getExtensions().remove(dataTypeExtension);
                    ctr++;
                }
            }
        }
        DataField targetField = new DataField();
        targetField.setName( ScorecardPMMLExtensionNames.DEFAULT_PREDICTED_FIELD );
        targetField.setDataType( DATATYPE.DOUBLE );
        targetField.setOptype( OPTYPE.CONTINUOUS );
        dataDictionary.getDataFields().add( targetField );
        dataDictionary.setNumberOfFields(BigInteger.valueOf(ctr + 1));
    }

    private void createAndSetOutput(Scorecard pmmlScorecard) {
        Extension classExtension = ScorecardPMMLUtils.getExtension(pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), PMMLExtensionNames.EXTERNAL_CLASS);
        Extension fieldExtension = ScorecardPMMLUtils.getExtension(pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), ScorecardPMMLExtensionNames.SCORECARD_RESULTANT_SCORE_FIELD);
        Extension reasonCodeExtension = ScorecardPMMLUtils.getExtension(pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), ScorecardPMMLExtensionNames.SCORECARD_RESULTANT_REASONCODES_FIELD);
        for (Object obj : pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
            if (obj instanceof Output) {
                Output output = (Output)obj;

                OutputField outputField = new OutputField();
                outputField.setDataType(DATATYPE.DOUBLE);
                outputField.setFeature(RESULTFEATURE.PREDICTED_VALUE);
                outputField.setDisplayName("Final Score");
                if ( fieldExtension != null ) {
                    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().remove(fieldExtension);
                    outputField.setName(fieldExtension.getValue());
                } else {
                    outputField.setName( "calculatedScore" );
                }

                if ( classExtension != null ) {
                    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().remove(classExtension);
                    outputField.getExtensions().add( classExtension );
                }
                output.getOutputFields().add(outputField);

                if ( pmmlScorecard.getUseReasonCodes() ) {
                    OutputField reasonCodeField = new OutputField();
                    reasonCodeField.setDataType( DATATYPE.STRING );
                    reasonCodeField.setFeature( RESULTFEATURE.REASON_CODE );
                    reasonCodeField.setDisplayName( "Principal Reason Code" );

                    if ( reasonCodeExtension != null ) {
                        pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().remove(reasonCodeExtension);
                        reasonCodeField.getExtensions().add( classExtension );
                        reasonCodeField.setName( reasonCodeExtension.getValue() );
                    } else {
                        reasonCodeField.setName( "reasonCode" );
                    }
                    output.getOutputFields().add( reasonCodeField );
                }


                break;
            }
        }
    }

    private void createAndSetPredicates(Scorecard pmmlScorecard) {
        for (Object obj : pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
            if (obj instanceof Characteristics) {
                Characteristics characteristics = (Characteristics) obj;
                for (org.dmg.pmml.pmml_4_2.descr.Characteristic characteristic : characteristics.getCharacteristics()) {
                    String dataType = ScorecardPMMLUtils.getExtensionValue(characteristic.getExtensions(), ScorecardPMMLExtensionNames.CHARACTERTISTIC_DATATYPE);
                    Extension predicateExtension = null;
                    for (Attribute attribute : characteristic.getAttributes()) {
                        String predicateAsString = "";
                        String field = ScorecardPMMLUtils.getExtensionValue(attribute.getExtensions(), ScorecardPMMLExtensionNames.CHARACTERTISTIC_FIELD);
                        for (Extension extension : attribute.getExtensions()) {
                            if ("predicateResolver".equalsIgnoreCase(extension.getName())) {
                                predicateAsString = extension.getValue();
                                predicateExtension = extension;
                                break;
                            }
                        }
                        setPredicatesForAttribute(attribute, dataType, field, predicateAsString);
                        attribute.getExtensions().remove(predicateExtension);
                    }
                }
            }
        }
    }

    private void setPredicatesForAttribute(Attribute pmmlAttribute, String dataType, String field, String predicateAsString) {
        predicateAsString = StringUtil.unescapeXML(predicateAsString);
        if (XLSKeywords.DATATYPE_NUMBER.equalsIgnoreCase(dataType) ||
                XLSKeywords.DATATYPE_DOUBLE.equalsIgnoreCase(dataType) ||
                XLSKeywords.DATATYPE_INTEGER.equalsIgnoreCase(dataType)) {
            setNumericPredicate(pmmlAttribute, field, predicateAsString);
        } else if (XLSKeywords.DATATYPE_TEXT.equalsIgnoreCase(dataType)) {
            setTextPredicate(pmmlAttribute, field, predicateAsString);
        } else if (XLSKeywords.DATATYPE_BOOLEAN.equalsIgnoreCase(dataType)) {
            setBooleanPredicate(pmmlAttribute, field, predicateAsString);
        }
    }

    private void setBooleanPredicate(Attribute pmmlAttribute, String field, String predicateAsString) {
        SimplePredicate simplePredicate = new SimplePredicate();
        simplePredicate.setField(field);
        simplePredicate.setOperator(PMMLOperators.EQUAL);
        if ("TRUE".equalsIgnoreCase(predicateAsString)){
            simplePredicate.setValue("TRUE");
        } else if ("FALSE".equalsIgnoreCase(predicateAsString)){
            simplePredicate.setValue("FALSE");
        }
        pmmlAttribute.setSimplePredicate(simplePredicate);
    }

    private void setTextPredicate(Attribute pmmlAttribute, String field, String predicateAsString) {
        String operator = "";
        if (predicateAsString.startsWith("=")) {
            operator = "=";
            predicateAsString = predicateAsString.substring(1);
        } else if (predicateAsString.startsWith("!=")) {
            operator = "!=";
            predicateAsString = predicateAsString.substring(2);
        }
        if (predicateAsString.contains(",")) {
            SimpleSetPredicate simpleSetPredicate = new SimpleSetPredicate();
            if ("!=".equalsIgnoreCase(operator)) {
                simpleSetPredicate.setBooleanOperator(PMMLOperators.IS_NOT_IN);
            } else {
                simpleSetPredicate.setBooleanOperator(PMMLOperators.IS_IN);
            }
            simpleSetPredicate.setField(field);
            predicateAsString = predicateAsString.trim();
            if  (predicateAsString.endsWith(",")) {
                predicateAsString = predicateAsString.substring(0, predicateAsString.length()-1);
            }
            Array array = new Array();
            array.setContent(predicateAsString.replace(",", " "));
            array.setType("string");
            array.setN(BigInteger.valueOf(predicateAsString.split(",").length));
            simpleSetPredicate.setArray(array);
            pmmlAttribute.setSimpleSetPredicate(simpleSetPredicate);
        } else {
            SimplePredicate simplePredicate = new SimplePredicate();
            simplePredicate.setField(field);
            if ("!=".equalsIgnoreCase(operator)) {
                simplePredicate.setOperator(PMMLOperators.NOT_EQUAL);
            } else {
                simplePredicate.setOperator(PMMLOperators.EQUAL);
            }
            simplePredicate.setValue(predicateAsString);
            pmmlAttribute.setSimplePredicate(simplePredicate);
        }
    }

    private void setNumericPredicate(Attribute pmmlAttribute, String field, String predicateAsString) {
        if (predicateAsString.indexOf("-") > 0) {
            CompoundPredicate compoundPredicate = new CompoundPredicate();
            compoundPredicate.setBooleanOperator("and");
            String left = predicateAsString.substring(0, predicateAsString.indexOf("-")).trim();
            String right = predicateAsString.substring(predicateAsString.indexOf("-") + 1).trim();
            SimplePredicate simplePredicate = new SimplePredicate();
            simplePredicate.setField(field);
            simplePredicate.setOperator(PMMLOperators.GREATER_OR_EQUAL);
            simplePredicate.setValue(left);
            compoundPredicate.getSimplePredicatesAndCompoundPredicatesAndSimpleSetPredicates().add(simplePredicate);
            simplePredicate = new SimplePredicate();
            simplePredicate.setField(field);
            simplePredicate.setOperator(PMMLOperators.LESS_THAN);
            simplePredicate.setValue(right);
            compoundPredicate.getSimplePredicatesAndCompoundPredicatesAndSimpleSetPredicates().add(simplePredicate);
            pmmlAttribute.setCompoundPredicate(compoundPredicate);
        } else {
            SimplePredicate simplePredicate = new SimplePredicate();
            simplePredicate.setField(field);
            if (predicateAsString.startsWith("<=")) {
                simplePredicate.setOperator(PMMLOperators.LESS_OR_EQUAL);
                simplePredicate.setValue(predicateAsString.substring(3).trim());
            } else if (predicateAsString.startsWith(">=")) {
                simplePredicate.setOperator(PMMLOperators.GREATER_OR_EQUAL);
                simplePredicate.setValue(predicateAsString.substring(3).trim());
            } else if (predicateAsString.startsWith("=")) {
                simplePredicate.setOperator(PMMLOperators.EQUAL);
                simplePredicate.setValue(predicateAsString.substring(2).trim());
            } else if (predicateAsString.startsWith("!=")) {
                simplePredicate.setOperator(PMMLOperators.NOT_EQUAL);
                simplePredicate.setValue(predicateAsString.substring(3).trim());
            } else if (predicateAsString.startsWith("<")) {
                simplePredicate.setOperator(PMMLOperators.LESS_THAN);
                simplePredicate.setValue(predicateAsString.substring(2).trim());
            } else if (predicateAsString.startsWith(">")) {
                simplePredicate.setOperator(PMMLOperators.GREATER_THAN);
                simplePredicate.setValue(predicateAsString.substring(2).trim());
            }
            pmmlAttribute.setSimplePredicate(simplePredicate);
        }
    }

    private void removeEmptyExtensions(Scorecard pmmlScorecard) {
        for (Object obj : pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
            if (obj instanceof Characteristics) {
                Characteristics characteristics = (Characteristics) obj;
                for (org.dmg.pmml.pmml_4_2.descr.Characteristic characteristic : characteristics.getCharacteristics()) {
                    List<Extension> toRemoveExtensionsList = new ArrayList<Extension>();
                    for (Extension extension : characteristic.getExtensions()) {
                        if (StringUtils.isEmpty(extension.getValue())) {
                            toRemoveExtensionsList.add(extension);
                        }
                    }
                    for (Extension extension : toRemoveExtensionsList) {
                        characteristic.getExtensions().remove(extension);
                    }

                    for (Attribute attribute : characteristic.getAttributes()) {
                        List<Extension> toRemoveExtensionsList2 = new ArrayList<Extension>();
                        for (Extension extension : attribute.getExtensions()) {
                            if (StringUtils.isEmpty(extension.getValue())) {
                                toRemoveExtensionsList2.add(extension);
                            }
                        }
                        for (Extension extension : toRemoveExtensionsList2) {
                            attribute.getExtensions().remove(extension);
                        }
                    }
                }
            }
        }
    }

}
